# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from torchrl.objectives import DistributionalDQNLoss, DQNLoss, HardUpdate, SoftUpdate
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater


def make_target_updater(
    cfg: DictConfig, loss_module: LossModule  # noqa: F821
) -> TargetNetUpdater | None:
    """Builds a target network weight update object."""
    if cfg.loss == "double":
        if not cfg.hard_update:
            target_net_updater = SoftUpdate(
                loss_module, eps=1 - 1 / cfg.value_network_update_interval
            )
        else:
            target_net_updater = HardUpdate(
                loss_module,
                value_network_update_interval=cfg.value_network_update_interval,
            )
    else:
        if cfg.hard_update:
            raise RuntimeError(
                "hard/soft updates are only meant to be used with a double (delayed) loss. "
                "Consider using --loss=double or discarding the hard_update flag."
            )
        target_net_updater = None
    return target_net_updater
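
# A minimal usage sketch (illustrative only, not executed on import): `loss_module`
# is assumed to be a LossModule with a delayed/target network (e.g. the DQNLoss
# built below with delay_value=True), and `cfg` any config exposing the fields read above.
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.create(
#         {"loss": "double", "hard_update": False, "value_network_update_interval": 1000}
#     )
#     updater = make_target_updater(cfg, loss_module)
#     # SoftUpdate is built with eps = 1 - 1/1000 = 0.999, i.e. after each call to
#     # updater.step(), target_params <- 0.999 * target_params + 0.001 * params.
#     updater.step()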


def make_dqn_loss(model, cfg) -> tuple[DQNLoss, TargetNetUpdater | None]:
    """Builds the DQN loss module."""
    loss_kwargs = {}
    if cfg.distributional:
        loss_class = DistributionalDQNLoss
    else:
        loss_kwargs.update({"loss_function": cfg.loss_function})
        loss_class = DQNLoss
    if cfg.loss not in ("single", "double"):
        raise NotImplementedError(
            f"Unsupported loss mode {cfg.loss!r}: expected 'single' or 'double'."
        )
    loss_kwargs.update({"delay_value": cfg.loss == "double"})
    loss_module = loss_class(model, **loss_kwargs)
    loss_module.make_value_estimator(gamma=cfg.gamma)
    target_net_updater = make_target_updater(cfg, loss_module)
    return loss_module, target_net_updater
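
# A minimal usage sketch (illustrative only; names such as `model`, `optimizer` and
# `sampled_tensordict` are placeholders): `model` is assumed to be a Q-value actor
# compatible with DQNLoss, and `cfg` a config exposing the fields read above
# (distributional, loss_function, loss, gamma, ...).
#
#     loss_module, target_net_updater = make_dqn_loss(model, cfg)
#     loss_td = loss_module(sampled_tensordict)  # TensorDict of loss values ("loss" entry)
#     loss_td["loss"].backward()
#     optimizer.step()
#     optimizer.zero_grad()
#     if target_net_updater is not None:
#         target_net_updater.step()  # refresh the target network weights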


@dataclass
class LossConfig:
    """Generic Loss config struct."""

    loss: str = "double"
    # whether double or single SAC loss should be used. Default=double
    hard_update: bool = False
    # whether hard updates (True) or soft updates (False, default) should be used with the double SAC loss.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    value_network_update_interval: int = 1000
    # how often the target value network weights are updated (in number of updates).
    # If soft-updates are used, the value is translated into a moving average decay by using
    # the formula decay=1-1/cfg.value_network_update_interval. Default=1000
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    num_q_values: int = 2
    # As suggested in the original SAC paper and in https://arxiv.org/abs/1802.09477, two (or more)
    # Q-value networks can be trained independently, with the minimum of their predictions used as the
    # state-action value estimate. Set this to 1 to disable.
    # REDQ uses an arbitrary number of Q-value functions to speed up learning in model-free contexts.
    target_entropy: Any = None
    # Target entropy for the policy distribution. Default is None (auto-computed as `target_entropy = -action_dim`).


@dataclass
class A2CLossConfig:
    """A2C Loss config struct."""

    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    entropy_coeff: float = 1e-3
    # Entropy factor for the A2C loss.
    critic_coeff: float = 1.0
    # Critic factor for the A2C loss.
    critic_loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).


@dataclass
class PPOLossConfig:
    """PPO Loss config struct."""

    loss: str = "clip"
    # PPO loss class: one of clip, kl, or base/<empty>. Default=clip

    # PPOLoss base parameters:
    gamma: float = 0.99
    # Decay factor for return computation. Default=0.99.
    lmbda: float = 0.95
    # lambda factor in GAE ('lambda' is a reserved keyword in Python, hence the spelling 'lmbda').
    entropy_bonus: bool = True
    # whether to add an entropy term to the PPO loss.
    entropy_coeff: float = 1e-3
    # Entropy factor for the PPO loss
    samples_mc_entropy: int = 1
    # Number of samples to use for a Monte-Carlo estimate of the entropy when the policy distribution has no closed-form entropy.
    loss_function: str = "smooth_l1"
    # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
    critic_coeff: float = 1.0
    # Critic loss multiplier when computing the total loss.

    # ClipPPOLoss parameters:
    clip_epsilon: float = 0.2
    # clipping threshold for the importance-sampling weight in the clipped PPO loss equation.

    # KLPENPPOLoss parameters:
    dtarg: float = 0.01
    # target KL divergence.
    beta: float = 1.0
    # initial KL divergence multiplier.
    increment: float = 2
    # multiplicative factor applied to beta when KL > dtarg. Valid range: increment >= 1.0
    decrement: float = 0.5
    # multiplicative factor applied to beta when KL < dtarg. Valid range: decrement <= 1.0
    samples_mc_kl: int = 1
    # Number of samples to use for a Monte-Carlo estimate of the KL divergence if necessary.
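
# Illustrative summary of how the parameters above enter the PPO losses (a sketch;
# the exact formulas are defined by the ClipPPOLoss / KLPENPPOLoss classes):
#
#     ClipPPOLoss, with ratio = pi_new(a|s) / pi_old(a|s) and advantage A:
#         objective = min(ratio * A, clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * A)
#
#     KLPENPPOLoss adapts beta towards the KL target dtarg, roughly:
#         if KL > dtarg: beta *= increment   (increment >= 1.0)
#         if KL < dtarg: beta *= decrement   (decrement <= 1.0)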
